/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <thrift/lib/cpp/protocol/TBase64Utils.h>

using std::string;

namespace apache::thrift::protocol {

// Maps each 6-bit value (0-63) to its character in the standard base64
// alphabet (A-Z, a-z, 0-9, '+', '/'). Only indices 0-63 are meaningful.
// The pointer itself is const so the table cannot be re-pointed at runtime.
static const uint8_t* const kBase64EncodeTable =
    reinterpret_cast<const uint8_t*>(
        "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/");

// Encodes one group of `len` input bytes (1-3) from `in` into `len + 1`
// base64 characters at `buf`, using kBase64EncodeTable. No '=' padding is
// written: for a short group the trailing positions of `buf` are left as-is.
// Any `len` other than 2 or 3 is handled as a single-byte group.
void base64_encode(const uint8_t* in, uint32_t len, uint8_t* buf) {
  // Sextet 0: the top six bits of byte 0.
  buf[0] = kBase64EncodeTable[(in[0] >> 2) & 0x3f];
  // The low two bits of byte 0 become the high bits of sextet 1.
  const auto carry0 = static_cast<uint8_t>((in[0] << 4) & 0x30);
  if (len != 2 && len != 3) {
    buf[1] = kBase64EncodeTable[carry0];
    return;
  }
  buf[1] = kBase64EncodeTable[carry0 | ((in[1] >> 4) & 0x0f)];
  // The low four bits of byte 1 become the high bits of sextet 2.
  const auto carry1 = static_cast<uint8_t>((in[1] << 2) & 0x3c);
  if (len == 2) {
    buf[2] = kBase64EncodeTable[carry1];
    return;
  }
  buf[2] = kBase64EncodeTable[carry1 | ((in[2] >> 6) & 0x03)];
  // Sextet 3: the low six bits of byte 2.
  buf[3] = kBase64EncodeTable[in[2] & 0x3f];
}

// Reverse lookup for the base64 alphabet: indexed by a character's byte value,
// yields its 6-bit value (0x00-0x3f). Every byte outside the alphabet,
// including the '=' padding character, maps to 0xff.
static const uint8_t kBase64DecodeTable[256] = {
    // 0x00-0x0f: control characters
    0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
    0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
    // 0x10-0x1f: control characters
    0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
    0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
    // 0x20-0x2f: ' ' .. '/'  ('+' -> 62, '/' -> 63)
    0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
    0xff, 0xff, 0xff, 0x3e, 0xff, 0xff, 0xff, 0x3f,
    // 0x30-0x3f: '0' .. '?'  ('0'-'9' -> 52-61)
    0x34, 0x35, 0x36, 0x37, 0x38, 0x39, 0x3a, 0x3b,
    0x3c, 0x3d, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
    // 0x40-0x4f: '@' .. 'O'  ('A'-'O' -> 0-14)
    0xff, 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06,
    0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e,
    // 0x50-0x5f: 'P' .. '_'  ('P'-'Z' -> 15-25)
    0x0f, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16,
    0x17, 0x18, 0x19, 0xff, 0xff, 0xff, 0xff, 0xff,
    // 0x60-0x6f: '`' .. 'o'  ('a'-'o' -> 26-40)
    0xff, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, 0x20,
    0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27, 0x28,
    // 0x70-0x7f: 'p' .. DEL  ('p'-'z' -> 41-51)
    0x29, 0x2a, 0x2b, 0x2c, 0x2d, 0x2e, 0x2f, 0x30,
    0x31, 0x32, 0x33, 0xff, 0xff, 0xff, 0xff, 0xff,
    // 0x80-0xff: bytes outside 7-bit ASCII
    0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
    0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
    0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
    0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
    0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
    0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
    0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
    0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
    0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
    0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
    0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
    0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
    0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
    0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
    0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
    0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
};

namespace {
// Decodes one group of `len` base64 characters (2-4) from `in` into
// len * 3 / 4 bytes at `buf`: 2 chars -> 1 byte, 3 -> 2, 4 -> 3.
// in[0] and in[1] are always read, so `len` must be at least 2. Characters
// are looked up in kBase64DecodeTable without validation.
//
// `in` and `buf` may alias (in-place decoding): buf[i] is only written after
// every read of in[i] has happened, and later steps read only in[i + 1] and
// beyond.
void base64_decode(const uint8_t* in, uint32_t len, uint8_t* buf) {
  const uint8_t sextet0 = kBase64DecodeTable[in[0]];
  const uint8_t sextet1 = kBase64DecodeTable[in[1]];
  buf[0] = static_cast<uint8_t>((sextet0 << 2) | (sextet1 >> 4));
  if (len <= 2) {
    return;
  }
  const uint8_t sextet2 = kBase64DecodeTable[in[2]];
  buf[1] = static_cast<uint8_t>(((sextet1 << 4) & 0xf0) | (sextet2 >> 2));
  if (len <= 3) {
    return;
  }
  const uint8_t sextet3 = kBase64DecodeTable[in[3]];
  buf[2] = static_cast<uint8_t>(((sextet2 << 6) & 0xc0) | sextet3);
}
} // namespace

void base64_decode(uint8_t* buf, uint32_t len) {
  return base64_decode(buf, len, buf);
}

// Encodes `binary` as base64 using the standard alphabet, with '=' padding.
// The result is 4 * ceil(binary.size() / 3) characters long; empty input
// yields an empty string.
std::string base64Encode(folly::ByteRange binary) {
  const size_t inSize = binary.size();
  // Pre-fill with '=' so the unused tail of the final group is already
  // padded: base64_encode writes only inLen + 1 characters per group.
  std::string base64((inSize + 2) / 3 * 4, '=');
  for (size_t idx = 0; idx < inSize; idx += 3) {
    // Keep the remaining length in size_t. Narrowing it to int (as before)
    // truncates once more than INT_MAX bytes remain and selects the wrong
    // group length, silently corrupting output for inputs over 2 GiB.
    const size_t inLen = std::min<size_t>(inSize - idx, 3);
    base64_encode(
        binary.begin() + idx,
        static_cast<uint32_t>(inLen),
        reinterpret_cast<uint8_t*>(&base64[idx / 3 * 4]));
  }
  return base64;
}

// Decodes base64 text (standard alphabet) into a newly allocated IOBuf.
//
// Trailing '=' padding is stripped first, so padded and unpadded input are
// both accepted. Characters are NOT validated: anything outside the alphabet
// (e.g. whitespace or an embedded '=') is looked up as 0xff and silently
// contributes garbage bits instead of raising an error. A dangling final
// group of a single character carries fewer than 8 bits and is ignored.
std::unique_ptr<folly::IOBuf> base64Decode(folly::StringPiece base64) {
  while (!base64.empty() && base64.back() == '=') {
    base64.pop_back();
  }
  const size_t len = base64.size();
  // Exact decoded size: 3 bytes per full 4-char group plus r * 3 / 4 for a
  // trailing group of r chars. Computed per group rather than as len * 3 / 4
  // so it cannot overflow size_t (and under-allocate) on 32-bit platforms.
  auto binary = folly::IOBuf::create(len / 4 * 3 + len % 4 * 3 / 4);
  // Valid base64-encoded strings have unpadded length equal to 4k+{0,2,3} for
  // some k. Stop once fewer than 2 characters remain to avoid reading past
  // the input on invalid (4k+1) lengths.
  for (size_t idx = 0; idx + 1 < len; idx += 4) {
    // size_t, not int: an int here truncates for inputs over INT_MAX chars,
    // and the subsequent inLen * 3 could overflow a signed int (UB).
    const size_t inLen = std::min<size_t>(len - idx, 4);
    base64_decode(
        reinterpret_cast<const uint8_t*>(base64.data() + idx),
        static_cast<uint32_t>(inLen),
        binary->writableTail());
    binary->append(inLen * 3 / 4);
  }
  return binary;
}

} // namespace apache::thrift::protocol
